{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-01T05:46:20.568141Z",
     "start_time": "2020-05-01T05:46:17.383252Z"
    }
   },
   "outputs": [],
   "source": [
    "from vision.ssd.vgg_ssd import create_vgg_ssd, create_vgg_ssd_predictor\n",
    "from vision.ssd.mobilenetv1_ssd import create_mobilenetv1_ssd, create_mobilenetv1_ssd_predictor\n",
    "from vision.ssd.mobilenetv1_ssd_lite import create_mobilenetv1_ssd_lite, create_mobilenetv1_ssd_lite_predictor\n",
    "from vision.ssd.squeezenet_ssd_lite import create_squeezenet_ssd_lite, create_squeezenet_ssd_lite_predictor\n",
    "from vision.ssd.mobilenet_v2_ssd_lite import create_mobilenetv2_ssd_lite, create_mobilenetv2_ssd_lite_predictor\n",
    "from vision.utils.misc import Timer\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2\n",
    "import sys\n",
    "import numpy as np\n",
    "import glob\n",
    "import os\n",
    "\n",
    "# Timer instance used by the detection loop (start()/end()) to time each prediction.\n",
    "timer = Timer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-01T05:46:47.287612Z",
     "start_time": "2020-05-01T05:46:32.497350Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading Trained Model is Done!\n",
      "\n",
      "Starting Detection...\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# MobileNet SSD # --- Le Anh trained ---\n",
    "net_type = \"mb1-ssd\"\n",
    "# net_type = \"mb2-ssd-lite\"\n",
    "\n",
    "# model_path = \"models/mb1-ssd-Epoch-19-Loss-6.089628639675322.pth\"\n",
    "model_path = \"models/mb1-ssd-Epoch-105-Loss-inf.pth\"\n",
    "# model_path = \"models/mb2-ssd-lite-mp-0_686.pth\"\n",
    "\n",
    "# label_path = \"models/voc-model-labels-orig.txt\"\n",
    "label_path = \"models/voc-model-labels.txt\"\n",
    "# --------------- #\n",
    "\n",
    "# Each supported net type maps to its (network factory, predictor factory)\n",
    "# pair, so the network and the predictor built from it cannot get out of sync.\n",
    "ssd_factories = {\n",
    "    'vgg16-ssd': (create_vgg_ssd, create_vgg_ssd_predictor),\n",
    "    'mb1-ssd': (create_mobilenetv1_ssd, create_mobilenetv1_ssd_predictor),\n",
    "    'mb1-ssd-lite': (create_mobilenetv1_ssd_lite, create_mobilenetv1_ssd_lite_predictor),\n",
    "    'mb2-ssd-lite': (create_mobilenetv2_ssd_lite, create_mobilenetv2_ssd_lite_predictor),\n",
    "    'sq-ssd-lite': (create_squeezenet_ssd_lite, create_squeezenet_ssd_lite_predictor),\n",
    "}\n",
    "\n",
    "# Read the labels inside a context manager so the file handle is closed.\n",
    "with open(label_path) as label_file:\n",
    "    class_names = [name.strip() for name in label_file]\n",
    "num_classes = len(class_names)\n",
    "\n",
    "if net_type not in ssd_factories:\n",
    "    # The message is built from the table so it lists every supported type.\n",
    "    print(\"The net type is wrong. It should be one of \" + \", \".join(ssd_factories) + \".\")\n",
    "    sys.exit(1)\n",
    "\n",
    "create_net, create_predictor = ssd_factories[net_type]\n",
    "\n",
    "# Build the network with is_test=True, then load model_path into it.\n",
    "net = create_net(num_classes, is_test=True)\n",
    "net.load(model_path)\n",
    "\n",
    "predictor = create_predictor(net, candidate_size=200)\n",
    "\n",
    "print(\"Loading Trained Model is Done!\\n\")\n",
    "\n",
    "print(\"Starting Detection...\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-01T07:06:30.429473Z",
     "start_time": "2020-05-01T07:06:30.420948Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['./images/bdd_test_5.jpg', './images/bdd_test_2.jpg', './images/bdd_test_7.jpg', './images/bdd_test_3.jpg', './images/bdd_test_8.jpg', './images/bdd_test_6.jpg', './images/bdd_test_4.jpg', './images/bdd_test_1.jpg']\n"
     ]
    }
   ],
   "source": [
    "# Sort the matches: glob returns them in arbitrary order, which would make the\n",
    "# processing order (and the printed log) differ between machines and runs.\n",
    "img_src = sorted(glob.glob(\"./images/*.jpg\"))\n",
    "print(img_src)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-01T07:35:18.921392Z",
     "start_time": "2020-05-01T07:35:18.906585Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['BACKGROUND', 'train', 'truck', 'traffic light', 'traffic sign', 'rider', 'person', 'bus', 'bike', 'car', 'motor']\n"
     ]
    }
   ],
   "source": [
    "# Display the class names read from label_path; the detection loop indexes\n",
    "# this list with the predicted label to get a name.\n",
    "print(class_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-01T07:49:33.855557Z",
     "start_time": "2020-05-01T07:49:33.847221Z"
    }
   },
   "outputs": [],
   "source": [
    "# Derive the names written to the text files from class_names (spaces become\n",
    "# underscores, e.g. \"traffic light\" -> \"traffic_light\") so the two lists stay\n",
    "# index-aligned with whatever label file was loaded.\n",
    "label_names = [name.replace(\" \", \"_\") for name in class_names]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-01T07:49:35.977060Z",
     "start_time": "2020-05-01T07:49:34.845026Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bdd_test_5\n",
      "Inference time:  0.03354835510253906\n",
      "Time: 0.09s, Detect Objects: 4.\n",
      "bdd_test_2\n",
      "Inference time:  0.03724956512451172\n",
      "Time: 0.08s, Detect Objects: 4.\n",
      "bdd_test_7\n",
      "Inference time:  0.030572175979614258\n",
      "Time: 0.07s, Detect Objects: 3.\n",
      "bdd_test_3\n",
      "Inference time:  0.03054046630859375\n",
      "Time: 0.07s, Detect Objects: 6.\n",
      "bdd_test_8\n",
      "Inference time:  0.031526803970336914\n",
      "Time: 0.07s, Detect Objects: 7.\n",
      "bdd_test_6\n",
      "Inference time:  0.033592939376831055\n",
      "Time: 0.07s, Detect Objects: 6.\n",
      "bdd_test_4\n",
      "Inference time:  0.02860879898071289\n",
      "Time: 0.07s, Detect Objects: 4.\n",
      "bdd_test_1\n",
      "Inference time:  0.03197956085205078\n",
      "Time: 0.07s, Detect Objects: 6.\n",
      "Image Shape:  (720, 1280, 3) . Check the result!\n"
     ]
    }
   ],
   "source": [
    "# Run the predictor on every image in img_src, draw the detections onto the\n",
    "# image, and write one text file per image with a line per detection:\n",
    "#   <label_name> <probability> <x1> <y1> <x2> <y2>\n",
    "\n",
    "# Make sure the output directory for the per-image text files exists.\n",
    "os.makedirs(\"./texts\", exist_ok=True)\n",
    "\n",
    "# One colour per class index, generated once so a class keeps the same colour\n",
    "# in every image. The table is sized from class_names: a fixed 10-row table is\n",
    "# too short for 11 classes and fails on the last label index.\n",
    "color = np.random.uniform(0, 255, size=(len(class_names), 3))\n",
    "\n",
    "# Shape of the most recently processed image, reported after the loop.\n",
    "last_shape = None\n",
    "\n",
    "for img_src_i in img_src:\n",
    "    img_name = os.path.splitext(os.path.basename(img_src_i))[0]\n",
    "    print(img_name)\n",
    "\n",
    "    orig_image = cv2.imread(img_src_i)\n",
    "    if orig_image is None:\n",
    "        # cv2.imread returns None instead of raising when a file cannot be read.\n",
    "        print(\"Could not read image:\", img_src_i)\n",
    "        continue\n",
    "    last_shape = orig_image.shape\n",
    "\n",
    "    # imread yields BGR; the predictor is given an RGB copy.\n",
    "    image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)\n",
    "\n",
    "    timer.start()\n",
    "    boxes, labels, probs = predictor.predict(image, 10, 0.4)\n",
    "    interval = timer.end()\n",
    "    print('Time: {:.2f}s, Detect Objects: {:d}.'.format(interval, labels.size(0)))\n",
    "\n",
    "    # Guard against a zero interval so the FPS overlay cannot divide by zero.\n",
    "    fps = 1 / interval if interval > 0 else 0.0\n",
    "\n",
    "    # The context manager closes the text file even if drawing raises.\n",
    "    with open(\"./texts/\" + img_name + \".txt\", \"w\") as text_file:\n",
    "        for i in range(boxes.size(0)):\n",
    "            # Convert tensor elements to plain Python numbers before handing\n",
    "            # them to OpenCV or using them as list indices.\n",
    "            label_idx = int(labels[i])\n",
    "            prob = float(probs[i])\n",
    "            x1, y1, x2, y2 = (int(v) for v in boxes[i, :])\n",
    "\n",
    "            label = f\"{class_names[label_idx]}: {prob:.2f}\"\n",
    "            box_color = tuple(float(c) for c in color[label_idx])\n",
    "\n",
    "            cv2.rectangle(orig_image, (x1, y1), (x2, y2), box_color, 2)\n",
    "\n",
    "            cv2.putText(orig_image, label,\n",
    "                        (x1 - 10, y1 - 10),\n",
    "                        cv2.FONT_HERSHEY_SIMPLEX,\n",
    "                        1,  # font scale\n",
    "                        box_color,\n",
    "                        2)  # line thickness\n",
    "\n",
    "            print(label_names[label_idx], f\"{prob:.6f}\",\n",
    "                  x1, y1, x2, y2,\n",
    "                  file=text_file)\n",
    "\n",
    "    # Draw the FPS overlay once per image, after the boxes so it stays on top\n",
    "    # and is present even when nothing was detected.\n",
    "    cv2.putText(orig_image, \"FPS = \" + str(int(fps)),\n",
    "                (1080, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)\n",
    "\n",
    "# NOTE(review): the annotated orig_image is neither saved nor displayed in this\n",
    "# cell; only the text files persist. Confirm whether that is intended.\n",
    "if last_shape is not None:\n",
    "    print(\"Image Shape: \", last_shape, \". Check the result!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
